package me.yricky.kaguya.base.lu

import me.yricky.kaguya.base.ExpressionWire
import me.yricky.kaguya.base.NamedSignal
import me.yricky.kaguya.base.Signal
import me.yricky.kaguya.base.Wire

/**
 * Expression logic unit: a logic unit whose logic can be written as an expression, e.g. AND,
 * OR or shift operations. An expression logic unit has exactly one output signal, of type
 * [ExpressionWire], exposed as [wire].
 */
sealed class ExpressionLogicUnit : LogicUnit() {
    /** The single output wire of this unit. */
    abstract val wire: ExpressionWire

    /** A one-element list holding [wire]; a new list is built on every read. */
    override val outputSignals: List<Wire>
        get() {
            val only: Wire = wire
            return listOf(only)
        }
}

/**
 * Inverter over [input] (presumably bitwise NOT, per its name). The output wire has the same
 * width as [input].
 *
 * @param input the signal to invert
 */
class Inv(
    val input: Signal,
) : ExpressionLogicUnit() {
    override val wire by lazy {
        val label = "Inv(${input})"
        ExpressionWire(label, input.width, this)
    }

    override val inputSignals: List<Signal> = listOf(input)

    // Initialised eagerly (after inputSignals), which forces [wire] to be built during construction.
    override val outputSignals: List<Wire> = listOf(wire)
}

/**
 * Multi-input AND. All [inputSignals] must share one width; the output wire has that width.
 *
 * @param inputSignals operands; must contain at least one signal, all of equal width
 * @throws IllegalArgumentException if [inputSignals] is empty or the widths differ
 */
class And(
    override val inputSignals: List<Signal>,
) : ExpressionLogicUnit() {
    // Emptiness is checked before any width is read, so an empty list fails with a clear
    // message instead of a bare NoSuchElementException from first().
    val width = requireNotNull(inputSignals.firstOrNull()) {
        "And requires at least one input signal"
    }.width

    init {
        // require() rather than assert(): JVM assertions are disabled unless -ea is passed.
        require(inputSignals.all { it.width == width }) {
            "And requires all inputs to have the same width, got ${inputSignals.map { it.width }}"
        }
    }

    override val wire by lazy {
        ExpressionWire("And", width, this)
    }

    override val outputSignals: List<Wire> = listOf(wire)
}

/**
 * Multi-input OR. All [inputSignals] must share one width; the output wire has that width.
 *
 * @param inputSignals operands; must contain at least one signal, all of equal width
 * @throws IllegalArgumentException if [inputSignals] is empty or the widths differ
 */
class Or(
    override val inputSignals: List<Signal>,
) : ExpressionLogicUnit() {
    // Emptiness is checked before any width is read, so an empty list fails with a clear
    // message instead of a bare NoSuchElementException from first().
    val width = requireNotNull(inputSignals.firstOrNull()) {
        "Or requires at least one input signal"
    }.width

    init {
        // require() rather than assert(): JVM assertions are disabled unless -ea is passed.
        require(inputSignals.all { it.width == width }) {
            "Or requires all inputs to have the same width, got ${inputSignals.map { it.width }}"
        }
    }

    override val wire by lazy {
        ExpressionWire("Or", width, this)
    }
}

/**
 * Single-input unit producing a 1-bit result from [input] (presumably a reduction AND over the
 * bits of [input] — confirm against the code generator).
 *
 * @param input the signal to reduce
 */
class SAnd(
    val input: Signal,
) : ExpressionLogicUnit() {
    val width = 1

    override val wire by lazy {
        ExpressionWire("SAnd", width, this)
    }

    // Declaration order matters: the eager listOf(wire) runs the lazy initializer, which hands
    // `this` to ExpressionWire. inputSignals is assigned first so it is never observed unset.
    override val inputSignals: List<Signal> = listOf(input)
    override val outputSignals: List<Wire> = listOf(wire)
}

/**
 * Single-input unit producing a 1-bit result from [input] (presumably a reduction OR over the
 * bits of [input] — confirm against the code generator).
 *
 * @param input the signal to reduce
 */
class SOr(
    val input: Signal,
) : ExpressionLogicUnit() {
    val width = 1

    override val wire by lazy {
        ExpressionWire("SOr", width, this)
    }

    // Declaration order matters: the eager listOf(wire) runs the lazy initializer, which hands
    // `this` to ExpressionWire. inputSignals is assigned first so it is never observed unset.
    override val inputSignals: List<Signal> = listOf(input)
    override val outputSignals: List<Wire> = listOf(wire)
}

/**
 * Single-input unit producing a 1-bit result from [input] (presumably a reduction XOR over the
 * bits of [input] — confirm against the code generator).
 *
 * @param input the signal to reduce
 */
class SXor(
    val input: Signal,
) : ExpressionLogicUnit() {
    val width = 1

    override val wire by lazy {
        ExpressionWire("SXor", width, this)
    }

    // Declaration order matters: the eager listOf(wire) runs the lazy initializer, which hands
    // `this` to ExpressionWire. inputSignals is assigned first so it is never observed unset.
    override val inputSignals: List<Signal> = listOf(input)
    override val outputSignals: List<Wire> = listOf(wire)
}

/**
 * Two-operand comparison (equality, per its name) producing a 1-bit result.
 *
 * @param inputs the two operands; both must have the same width
 * @throws IllegalArgumentException if the operand widths differ
 */
class Equal(val inputs: Pair<Signal, Signal>) : ExpressionLogicUnit() {
    init {
        // require() rather than assert(): JVM assertions are disabled unless -ea is passed.
        require(inputs.first.width == inputs.second.width) {
            "Equal requires operands of the same width, got ${inputs.first.width} and ${inputs.second.width}"
        }
    }

    // Labelled "eq" to match the comparator labels used by MoreThan ("mt") and LessThan ("lt").
    override val wire by lazy {
        ExpressionWire("eq", 1, this)
    }
    override val inputSignals: List<Signal> = listOf(inputs.first, inputs.second)
}

/**
 * Two-operand comparison (greater-than, per its name; presumably `inputs.first > inputs.second`
 * — confirm against the code generator) producing a 1-bit result.
 *
 * @param inputs the two operands; both must have the same width
 * @throws IllegalArgumentException if the operand widths differ
 */
class MoreThan(val inputs: Pair<Signal, Signal>) : ExpressionLogicUnit() {
    init {
        // require() rather than assert(): JVM assertions are disabled unless -ea is passed.
        require(inputs.first.width == inputs.second.width) {
            "MoreThan requires operands of the same width, got ${inputs.first.width} and ${inputs.second.width}"
        }
    }

    override val wire by lazy {
        ExpressionWire("mt", 1, this)
    }
    override val inputSignals: List<Signal> = listOf(inputs.first, inputs.second)
}

/**
 * Two-operand comparison (less-than, per its name; presumably `inputs.first < inputs.second`
 * — confirm against the code generator) producing a 1-bit result.
 *
 * @param inputs the two operands; both must have the same width
 * @throws IllegalArgumentException if the operand widths differ
 */
class LessThan(val inputs: Pair<Signal, Signal>) : ExpressionLogicUnit() {
    init {
        // require() rather than assert(): JVM assertions are disabled unless -ea is passed.
        require(inputs.first.width == inputs.second.width) {
            "LessThan requires operands of the same width, got ${inputs.first.width} and ${inputs.second.width}"
        }
    }

    override val wire by lazy {
        ExpressionWire("lt", 1, this)
    }
    override val inputSignals: List<Signal> = listOf(inputs.first, inputs.second)
}

/**
 * Multi-input XOR. All [inputSignals] must share one width; the output wire has that width.
 *
 * @param inputSignals operands; must contain at least one signal, all of equal width
 * @throws IllegalArgumentException if [inputSignals] is empty or the widths differ
 */
class Xor(
    override val inputSignals: List<Signal>,
) : ExpressionLogicUnit() {
    // Emptiness is checked before any width is read, so an empty list fails with a clear
    // message instead of a bare NoSuchElementException from first().
    val width = requireNotNull(inputSignals.firstOrNull()) {
        "Xor requires at least one input signal"
    }.width

    init {
        // require() rather than assert(): JVM assertions are disabled unless -ea is passed.
        require(inputSignals.all { it.width == width }) {
            "Xor requires all inputs to have the same width, got ${inputSignals.map { it.width }}"
        }
    }

    override val wire by lazy {
        ExpressionWire("Xor", width, this)
    }
}

/**
 * Addition of two equal-width signals. The output wire has the same width as the operands,
 * so no extra carry bit is added to the result width.
 *
 * @param inputs the two operands; both must have the same width
 * @throws IllegalArgumentException if the operand widths differ
 */
class Plus(
    val inputs: Pair<Signal, Signal>,
) : ExpressionLogicUnit() {
    val width = inputs.first.width

    init {
        // require() rather than assert(): JVM assertions are disabled unless -ea is passed.
        require(inputs.first.width == inputs.second.width) {
            "Plus requires operands of the same width, got ${inputs.first.width} and ${inputs.second.width}"
        }
    }

    override val wire by lazy {
        ExpressionWire("plus", width, this)
    }
    override val inputSignals: List<Signal> = listOf(inputs.first, inputs.second)
}

/**
 * Takes a slice of the input signal.
 *
 * @param input the signal to slice
 * @param start index of the lowest bit of the slice. In KaguyaHDL bits are numbered from the
 *   right (bit 0 is the rightmost bit), so this must be non-negative.
 * @param sliceWidth number of bits in the slice; must be positive. The slice is the part that
 *   extends [sliceWidth] bits leftward from [start], including [start] itself.
 *   Example: with [start] = 2 and [sliceWidth] = 5 the slice is input[6:2].
 * @throws IllegalArgumentException if [start] is negative, [sliceWidth] is not positive, or
 *   the slice extends past `input.width`
 */
class Slice(
    val input: NamedSignal,
    val start: Int,
    val sliceWidth: Int,
) : ExpressionLogicUnit() {
    init {
        // require() rather than assert(): JVM assertions are disabled unless -ea is passed.
        require(start >= 0) { "Slice start must be non-negative, got $start" }
        require(sliceWidth > 0) { "Slice width must be positive, got $sliceWidth" }
        require(start + sliceWidth <= input.width) {
            "Slice [${start + sliceWidth - 1}:$start] exceeds input width ${input.width}"
        }
    }

    override val wire by lazy {
        ExpressionWire("Slice(${input},from = $start, width = ${sliceWidth})", sliceWidth, this)
    }

    override val inputSignals: List<Signal> = listOf(input)
}

/**
 * Replicates [input] [repeat] times; the output wire width is `input.width * repeat`.
 *
 * @param input the signal to replicate
 * @param repeat replication count; must be positive
 * @throws IllegalArgumentException if [repeat] is not positive
 */
class Repeat(
    val input: Signal,
    val repeat: Int,
) : ExpressionLogicUnit() {
    init {
        // A non-positive count would produce a zero or negative output width.
        require(repeat > 0) { "Repeat count must be positive, got $repeat" }
    }

    override val wire by lazy {
        ExpressionWire("Repeat", input.width * repeat, this)
    }

    override val inputSignals: List<Signal> = listOf(input)
}

/**
 * Concatenates [inputs] into one wire whose width is the sum of the input widths.
 * When an input is itself the output wire of another [Joint], that Joint's own inputs are
 * spliced into [inputSignals] instead, so nested concatenations are flattened.
 *
 * @param inputs signals to concatenate; must not be empty
 * @throws IllegalArgumentException if [inputs] is empty
 */
class Joint(
    private val inputs: List<Signal>,
) : ExpressionLogicUnit() {
    init {
        // Fail at construction; reduce() on an empty list would otherwise only throw
        // UnsupportedOperationException later, on the first read of wire.
        require(inputs.isNotEmpty()) { "Joint requires at least one input signal" }
    }

    override val wire by lazy {
        ExpressionWire("Joint", inputs.sumOf { it.width }, this)
    }

    override val inputSignals: List<Signal> = inputs.flatMap { signal ->
        if (signal is ExpressionWire && signal.logicUnit is Joint) {
            // A nested Joint already flattened its own inputs in its constructor,
            // so taking its inputSignals is enough; no recursion needed here.
            signal.logicUnit.inputSignals
        } else {
            listOf(signal)
        }
    }
}